{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "408e4b71-7896-4bb7-a22f-169055950115",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://mirrors.cloud.aliyuncs.com/pypi/simple/\n",
      "Collecting openpyxl\n",
      "  Downloading http://mirrors.cloud.aliyuncs.com/pypi/packages/6a/94/a59521de836ef0da54aaf50da6c4da8fb4072fb3053fa71f052fd9399e7a/openpyxl-3.1.2-py2.py3-none-any.whl (249 kB)\n",
      "\u001b[K     |████████████████████████████████| 249 kB 110.7 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting et-xmlfile\n",
      "  Downloading http://mirrors.cloud.aliyuncs.com/pypi/packages/96/c2/3dd434b0108730014f1b96fd286040dc3bcb70066346f7e01ec2ac95865f/et_xmlfile-1.1.0-py3-none-any.whl (4.7 kB)\n",
      "Installing collected packages: et-xmlfile, openpyxl\n",
      "Successfully installed et-xmlfile-1.1.0 openpyxl-3.1.2\n"
     ]
    }
   ],
   "source": [
    "!pip install openpyxl"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db5d3b4a-43af-43ec-afd1-545e15febfe8",
   "metadata": {},
   "source": [
    "# 加载1.8B Chat Int4模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "id": "51b962d5-3b98-408b-b132-2d5e73e6fa23",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-06 21:06:27,395 - modelscope - WARNING - Model revision not specified, use revision: v1.0.0\n",
      "Try importing flash-attention for faster inference...\n"
     ]
    }
   ],
   "source": [
    "from modelscope import snapshot_download\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model_dir = snapshot_download('qwen/Qwen-1_8B-Chat-Int4')\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_dir,\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True\n",
    ").eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "020af6b3-022e-4cad-967c-e8144e71780a",
   "metadata": {},
   "source": [
    "# 提示词工程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "id": "4cbb0900-6554-4751-bbc8-326fac486cbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 城市数据\n",
    "with open('city.txt','r',encoding='utf-8') as fp:\n",
    "    city_list=fp.readlines()\n",
    "    city_list=[line.strip().split(' ')[1] for line in city_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "b6338a20-6c9c-49fa-9f5b-9c606d747260",
   "metadata": {},
   "outputs": [],
   "source": [
    "Q='青岛4月6日下雨么?'\n",
    "\n",
    "prompt_template='''\n",
    "给定一句话：“%s”，请你按步骤要求工作。\n",
    "\n",
    "步骤1：识别这句话中的城市和日期共2个信息\n",
    "步骤2：根据城市和日期信息，生成JSON字符串，格式为{\"city\":城市,\"date\":日期}\n",
    "\n",
    "请问，这个JSON字符串是：\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65138cce-9cd6-4dcf-8d61-025a13864f6c",
   "metadata": {},
   "source": [
    "# 生成SFT微调数据"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d44d76f5-0008-42a1-a459-0d9d54d23ce6",
   "metadata": {},
   "source": [
    "Qwen的SFT数据格式要求:\n",
    "```\n",
    "[\n",
    "  {\n",
    "    \"id\": \"identity_0\",\n",
    "    \"conversations\": [\n",
    "      {\n",
    "        \"from\": \"user\",\n",
    "        \"value\": \"你好\"\n",
    "      },\n",
    "      {\n",
    "        \"from\": \"assistant\",\n",
    "        \"value\": \"我是一个语言模型，我叫通义千问。\"\n",
    "      }\n",
    "    ]\n",
    "  }\n",
    "]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "id": "7a9673d1-3359-41ef-b851-a4be82765701",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import json\n",
    "import time \n",
    "\n",
    "Q_arr=[]\n",
    "A_arr=[]\n",
    "\n",
    "Q_list=[\n",
    "    ('{city}{year}年{month}月{day}日的天气','%Y-%m-%d'),\n",
    "    ('{city}{year}年{month}月{day}号的天气','%Y-%m-%d'),\n",
    "    ('{city}{month}月{day}日的天气','%m-%d'),\n",
    "    ('{city}{month}月{day}号的天气','%m-%d'),\n",
    "\n",
    "    ('{year}年{month}月{day}日{city}的天气','%Y-%m-%d'),\n",
    "    ('{year}年{month}月{day}号{city}的天气','%Y-%m-%d'),\n",
    "    ('{month}月{day}日{city}的天气','%m-%d'),\n",
    "    ('{month}月{day}号{city}的天气','%m-%d'),\n",
    "\n",
    "    ('你们{year}年{month}月{day}日去{city}玩吗？','%Y-%m-%d'),\n",
    "    ('你们{year}年{month}月{day}号去{city}玩么？','%Y-%m-%d'),\n",
    "    ('你们{month}月{day}日去{city}玩吗？','%m-%d'),\n",
    "    ('你们{month}月{day}号去{city}玩吗？','%m-%d'),\n",
    "]\n",
    "\n",
    "# 生成一批\"1月2号\"、\"1月2日\"、\"2023年1月2号\", \"2023年1月2日\", \"2023-02-02\", \"03-02\"之类的话术, 教会它做日期转换\n",
    "for i in range(1000):\n",
    "    Q=Q_list[random.randint(0,len(Q_list)-1)]\n",
    "    city=city_list[random.randint(0,len(city_list)-1)]\n",
    "    year=random.randint(1990,2025)\n",
    "    month=random.randint(1,12)\n",
    "    day=random.randint(1,28)\n",
    "    time_str='{}-{}-{}'.format(year,month,day)\n",
    "    date_field=time.strftime(Q[1],time.strptime(time_str,'%Y-%m-%d'))\n",
    "    Q=Q[0].format(city=city,year=year,month=month,day=day) # 问题\n",
    "    A=json.dumps({'city':city,'date':date_field},ensure_ascii=False)  # 回答\n",
    "\n",
    "    Q_arr.append(prompt_template%(Q,))\n",
    "    A_arr.append(A)\n",
    "\n",
    "import pandas as pd \n",
    "\n",
    "df=pd.DataFrame({'Prompt':Q_arr,'Completion':A_arr})\n",
    "df.to_excel('train.xlsx')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df864230-32ac-4e01-ab5a-83a6758c773f",
   "metadata": {},
   "source": [
    "# 微调模型，生成到output_qwen\n",
    "\n",
    "bash finetune/finetune_qlora_single_gpu.sh  -m /root/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat-Int4 -d /root/Qwen/train.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "799b107b-4c01-4853-bb51-ca1552ab3314",
   "metadata": {},
   "source": [
    "# 加载SFT后的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "id": "60a4a4f0-32ef-4254-a3f1-31648e580fa6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Try importing flash-attention for faster inference...\n"
     ]
    }
   ],
   "source": [
    "from peft import AutoPeftModelForCausalLM\n",
    "\n",
    "model = AutoPeftModelForCausalLM.from_pretrained(\n",
    "    'output_qwen', # path to the output directory\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True\n",
    ").eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "id": "5a053c4a-d8b1-44e9-b4f0-0906f9115333",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q:2020年4月16号三亚下雨么？\n",
      "A:{\"city\":三亚,\"date\":2020-04-16}\n",
      "\n",
      "Q:青岛3-15号天气预报\n",
      "A:{\"city\": \"青岛\",\"date\": \"03-15\"}\n",
      "\n",
      "Q:5月6号下雪么，城市是威海\n",
      "A:{\"city\":威海,\"date\":05-06}\n",
      "\n",
      "Q:青岛2023年12月30号有雾霾么?\n",
      "A:{\"city\": \"青岛\",\"date\": \"2023-12-30\"}\n",
      "\n",
      "Q:我打算6月1号去北京旅游，请问天气怎么样？\n",
      "A:{\"city\": \"北京\",\"date\": \"06-01\"}\n",
      "\n",
      "Q:你们打算1月3号坐哪一趟航班去上海？\n",
      "A:{\"city\": \"上海\",\"date\": \"01-03\"}\n",
      "\n",
      "Q:小明和小红是8月8号在上海结婚么?\n",
      "A:{\"city\": \"上海\",\"date\": \"08-08\"}\n",
      "\n",
      "Q:一起去东北看冰雕么，大概是1月15号左右，我们3个人一起\n",
      "A:{\"city\": \"东北\", \"date\": \"01-15\"}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model.generation_config.top_p=0 # 只选择概率最高的token\n",
    "\n",
    "Q_list=['2020年4月16号三亚下雨么？','青岛3-15号天气预报','5月6号下雪么，城市是威海','青岛2023年12月30号有雾霾么?','我打算6月1号去北京旅游，请问天气怎么样？','你们打算1月3号坐哪一趟航班去上海？','小明和小红是8月8号在上海结婚么?',\n",
    "        '一起去东北看冰雕么，大概是1月15号左右，我们3个人一起']\n",
    "for Q in Q_list:\n",
    "    prompt=prompt_template%(Q,)\n",
    "    A,hist=model.chat(tokenizer,prompt,history=None)\n",
    "    print('Q:%s\\nA:%s\\n'%(Q,A))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39e860c1-4572-456d-aedc-703f192f599a",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt='青岛海边钓鱼需要特别注意什么？'\n",
    "resp,hist=model.chat(tokenizer,prompt,history=None)\n",
    "print(resp)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
